from argparse import ArgumentParser

import pytorch_lightning.loggers as pl_loggers
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import DDPPlugin

from dataset import PredRNNDataModule, PredRNNAWSDataModule, UNetDataModule, UNetAWSDataModule
from models import PredRNN, PredRNN_AWS, Unet_aws, Unet


def main(args):
    """Build the selected model and datamodule, then train them with Lightning.

    Args:
        args: Parsed command-line namespace. ``args.model_name`` selects which
            (LightningModule, DataModule) pair to build. The whole namespace is
            also forwarded as keyword arguments to both constructors.

    Raises:
        ValueError: If ``args.model_name`` is not a supported model name.
    """
    dict_args = vars(args)

    # model_name -> (model class, datamodule class). The UNet entries used to
    # be swapped (model built from the datamodule class and vice versa).
    registry = {
        "predrnn_aws": (PredRNN_AWS, PredRNNAWSDataModule),
        "unet": (Unet, UNetDataModule),
        "unet_aws": (Unet_aws, UNetAWSDataModule),
        "predrnn": (PredRNN, PredRNNDataModule),
    }
    try:
        model_cls, data_cls = registry[args.model_name]
    except KeyError:
        # Fail before any callbacks/loggers are created; previously an unknown
        # name only printed a message and later crashed with a NameError.
        raise ValueError(
            f"Unknown model_name {args.model_name!r}; "
            f"expected one of {sorted(registry)}"
        ) from None

    model = model_cls(**dict_args)
    data = data_cls(**dict_args)
    # Every supported model trains with DDP's find_unused_parameters disabled.
    flag = False

    save_callback = ModelCheckpoint(
        monitor=args.save_monitor,
        dirpath=args.save_dirpath,
        filename=args.save_filename,
        save_top_k=args.save_top_k,
        mode=args.save_mode,
    )

    tb_logger = pl_loggers.TensorBoardLogger(save_dir=args.tensorboard_save_path,
                                             name=args.tensorboard_exp_name)
    callbacks = [save_callback, ]

    trainer = Trainer(logger=tb_logger,
                      accelerator="ddp",
                      plugins=DDPPlugin(find_unused_parameters=flag),
                      gpus=args.gpus,
                      val_check_interval=args.check_val_rate,
                      callbacks=callbacks,
                      progress_bar_refresh_rate=args.refresh_rate,
                      max_epochs=args.max_epochs,
                      resume_from_checkpoint=args.resume_path)

    trainer.fit(model, datamodule=data)

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--gpus", type=str, default="2,3")
    parser.add_argument("--tensorboard_save_path", type=str, default=None)
    parser.add_argument("--tensorboard_exp_name", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_works", type=int, default=8)
    parser.add_argument("--pin_memory",  type=int, default=0)
    parser.add_argument("--max_epochs", type=int, default=500)
    parser.add_argument("--train_file", type=str, default="")
    parser.add_argument("--val_file", type=str, default="")
    parser.add_argument("--test_file", type=str, default="")
    parser.add_argument("--root_dir", type=str, default="")
    parser.add_argument("--save_monitor", type=str, default="")
    parser.add_argument("--model_name", type=str, default="")
    parser.add_argument("--save_dirpath", type=str, default=None)
    parser.add_argument("--save_filename", type=str, default='weights-{epoch:03d}-{valid_loss_fx:.3f}')
    parser.add_argument("--save_top_k", type=int, default=50)
    parser.add_argument("--save_mode", type=str, default="min")
    parser.add_argument("--refresh_rate", type=int, default=20)
    parser.add_argument("--check_val_rate", type=float, default=0.5)
    parser.add_argument("--resume_path", type=str, default=None)

    # First pass: read only the generic args so we know which model is
    # selected; unknown (model-specific) flags are tolerated here.
    temp_args, _ = parser.parse_known_args()

    # Let the selected model add its own arguments. These keys must match the
    # model names dispatched in main() (previously "unet3d"/"unet3d_aws",
    # which main() rejected).
    model_arg_providers = {
        "predrnn_aws": PredRNN_AWS,
        "unet": Unet,
        "unet_aws": Unet_aws,
        "predrnn": PredRNN,
    }
    model_cls = model_arg_providers.get(temp_args.model_name)
    if model_cls is not None:
        # Always rebind `parser`: the helper may return a new parser rather
        # than mutating this one, and only the returned one would know the
        # model's arguments.
        parser = model_cls.add_model_specific_args(parser)

    args = parser.parse_args()
    main(args)
